{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "\n",
    "# Directory template for the TFRecord datasets; the `{0}` placeholder is filled\n",
    "# in with a dataset name through `file_pattern.format(...)`.\n",
    "file_pattern = '/mnt/tess/astronet/tfrecords-38-{0}'\n",
    "# CSV file that is read with pandas into `tce_table`.\n",
    "tces_file = '/mnt/tess/astronet/tces-v14-all.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Load the TCE CSV (header on the first row) and index the rows by 'Astro ID'.\n",
    "tce_table = (\n",
    "    pd.read_csv(tces_file, header=0, low_memory=False)\n",
    "    .set_index('Astro ID')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the train/val TFRecord shards into one dataset per year (y1-*, y2-*,\n",
    "# y3-*), routing every example by the `Year` that `tce_table` lists for its\n",
    "# `astro_id` feature.\n",
    "YEARS = range(1, 4)\n",
    "\n",
    "# Build the astro_id -> Year lookup once rather than masking the whole table\n",
    "# for every record. For duplicated IDs the first row wins, as before.\n",
    "year_column = tce_table['Year']\n",
    "year_by_id = year_column[~year_column.index.duplicated(keep='first')]\n",
    "\n",
    "for split in ['train', 'val']:\n",
    "  filenames = tf.io.gfile.glob(file_pattern.format(split) + '/*')\n",
    "\n",
    "  writers = {}\n",
    "  try:\n",
    "    # Open the writers inside the try block so that a failure part-way through\n",
    "    # still closes the ones that were already created.\n",
    "    for i in YEARS:\n",
    "      dr = file_pattern.format('y' + str(i) + '-' + split)\n",
    "      tf.io.gfile.makedirs(dr)\n",
    "      writers[i] = tf.io.TFRecordWriter(dr + '/00000-of-00001')\n",
    "\n",
    "    counts = {i: 0 for i in YEARS}\n",
    "    for filename in filenames:\n",
    "      print(filename, end='')\n",
    "      for r in tf.data.TFRecordDataset(filename):\n",
    "        ex = tf.train.Example.FromString(r.numpy())\n",
    "        features = ex.features.feature\n",
    "        # Examples without an astro_id feature are skipped, as before.\n",
    "        if 'astro_id' not in features:\n",
    "          continue\n",
    "        ex_id = features['astro_id'].int64_list.value[0]\n",
    "        yr = year_by_id.at[ex_id]  # KeyError if the ID is not in tce_table\n",
    "        counts[yr] += 1\n",
    "        writers[yr].write(ex.SerializeToString())\n",
    "        # Report after the update so the last line shows the final counts.\n",
    "        stat = ' '.join([f'{k}:{v}' for k, v in counts.items()])\n",
    "        print(f'\\r{filename}: {stat}', end='')\n",
    "      print('')\n",
    "  finally:\n",
    "    # close() is the public API; the context-manager dunders are not called by hand.\n",
    "    for wr in writers.values():\n",
    "      wr.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
